import torch
import torch.nn as nn


class QuantileLoss(nn.Module):
    """Pinball (quantile) loss summed over several quantile levels.

    ``preds`` packs one prediction per quantile along its last dimension as
    contiguous blocks: columns ``[i * D, (i + 1) * D)`` hold the prediction
    for ``quantiles[i]``, where ``D = target.shape[-1]``. For a quantile ``q``
    and error ``e = target - pred`` the per-element loss is
    ``max((q - 1) * e, q * e)``. Per-element losses are summed over the
    quantiles and then averaged over all remaining elements.
    """

    def __init__(self, quantiles):
        """Store the quantile levels.

        Args:
            quantiles: Non-empty sequence of quantile levels, each in [0, 1].

        Raises:
            ValueError: If ``quantiles`` is empty, or contains a level outside
                [0, 1] (such levels can make the per-element loss negative).
        """
        super().__init__()
        if len(quantiles) == 0:
            raise ValueError("At least one quantile is required.")
        for q in quantiles:
            if not 0.0 <= float(q) <= 1.0:
                raise ValueError(f"Quantiles must lie in [0, 1]. Got {q}.")
        self.quantiles = quantiles

    def forward(self, preds, target):
        """Compute the mean quantile loss.

        Args:
            preds: Tensor whose last dimension has size
                ``len(quantiles) * target.shape[-1]``, laid out as described
                in the class docstring.
            target: Tensor of targets; its first dimension must match that
                of ``preds``.

        Returns:
            A scalar tensor: the per-element losses summed over quantiles,
            then averaged.

        Raises:
            ValueError: If the first dimensions of ``preds`` and ``target``
                differ, or if the last dimension of ``preds`` is not
                ``len(quantiles) * target.shape[-1]``.
        """
        if preds.size(0) != target.size(0):
            raise ValueError(
                f"Batch size mismatch between predictions and targets. Got preds: {preds.size(0)}, target: {target.size(0)}")

        output_dim = target.shape[-1]
        num_quantiles = len(self.quantiles)
        expected_dim = num_quantiles * output_dim
        # Without this check a short ``preds`` yields a truncated last slice
        # (e.g. width 1) that silently broadcasts against ``target``.
        if preds.shape[-1] != expected_dim:
            raise ValueError(
                f"Expected preds last dimension {expected_dim} "
                f"(= {num_quantiles} quantiles x {output_dim} outputs), "
                f"got {preds.shape[-1]}.")

        losses = []
        for i, q in enumerate(self.quantiles):
            preds_i = preds[..., i * output_dim:(i + 1) * output_dim]
            errors = target - preds_i
            # Pinball loss: under-prediction weighted by q, over-prediction by (1 - q).
            losses.append(torch.max((q - 1) * errors, q * errors))

        # Sum over the new leading quantile axis, then average everything else.
        return torch.stack(losses, dim=0).sum(dim=0).mean()
